#!/usr/bin/python2 -S
# -*- coding: utf-8 -*-
import sys
# Python 2 only: the site module removes sys.setdefaultencoding during
# start-up.  The shebang's -S flag skips the automatic `import site`, and
# reload(sys) brings the function back if the script was started without -S.
reload(sys)
# Switch the implicit str <-> unicode conversion codec from ASCII to UTF-8
# (presumably needed because the messages in this file are non-ASCII).
sys.setdefaultencoding("utf-8")

# Imported by hand because -S suppressed the automatic import.
import site
import traceback
import codecs
import os
import sqlite3
#import pprint
import subprocess

import his_loader
import db_operator
import analyzer
import data_struct 

# Program name shown in usage messages; default value of every handler's argv0.
MY_NAME='dla'

# 工具入口

def print_usage(argv0=MY_NAME):
    """Print the list of supported sub-commands to stdout.

    argv0 -- program name substituted at the start of every usage line.

    The table mirrors the dispatch in main(): every sub-command that main()
    accepts has one line here, and the defaults quoted match the handlers.
    """
    # One template per sub-command, in display order; each takes argv0 once.
    usage_lines = (
        "  %s list        列出数据库中已导入的代码",
        "  %s gen_alpha   <代码>  以<代码>为基准，设定其他代码的alpha",
        "  %s correl      <代码1> <代码2> 考察指定两个代码的相关性 ",
        "  %s correl_all  [<代码列表文件>] 考察所有代码之间的相关性 ",
        "  %s show_correl 列出记录与DB的相关性关系 ",
        "  %s faster_horse    <代码1> <代码2> [短周期=1] [中周期=20] [开始] [结束] 回测‘换快马’策略",
        "  %s faster_horse2   <代码1>,<代码2>[,<代码3> ...] [MA周期=1] [门槛=1] [开始] [结束] 回测‘换多马’策略",
        "  %s probe_faster_horse   <代码1> <代码2> 所有‘换快马’策略的所有参数组合",
        # handle_faster_horse_all is dispatched by main() but was not listed.
        "  %s faster_horse_all  <代码列表> ",
        # handle_rotate defaults max_hold to 1, not 5.
        "  %s rotate  <代码1>,<代码2>[,<代码3> ...] [最多持仓数=1] [开始] [结束] 回测‘轮换’策略",
        "  %s import      <日线文件>   导入单个日线文件",
        "  %s importdir   <日线目录>   导入指定目录下所有日线文件(.txt) ",
        "  %s help        显示本帮助",
    )

    print("Usage:")
    for line in usage_lines:
        print(line % argv0)


# 处理 'import' 子命令
def handle_import(argv, argv0=MY_NAME):
    """Handle the 'import' sub-command: load one daily-quote file into the DB.

    argv  -- arguments after the sub-command; argv[0] is the file to import.
    argv0 -- program name used in the usage message.

    Returns 0 on success.  Returns 1 when the file argument is missing or
    when any exception is raised while importing; in the latter case the
    traceback and the exception text are printed and nothing is committed.
    """
    if len(argv) < 1:
        print("  %s import    <日线文件>   导入单个日线文件" % argv0)
        print("\n 未指定 <日线文件>\n")
        return 1

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        his_loader.load_some(argv[0], dbcur, inventory_ranges)

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0

# 处理 'importdir' 子命令
def handle_importdir(argv, argv0=MY_NAME):
    """Handle the 'importdir' sub-command.

    Runs `find <argv[0]> -name *.txt -exec ./mac_csv_helper {} ;`, i.e. the
    helper is started once for every .txt file found under the directory.

    argv  -- arguments after the sub-command; argv[0] is the directory.
    argv0 -- program name used in the usage message.

    Returns 1 when the directory argument is missing, otherwise the exit
    status of the find process.
    """
    if len(argv) < 1:
        print("  %s importdir <日线目录>   导入指定目录下所有日线文件(.txt)" % argv0)
        print("\n 未指定 <日线目录>\n")
        return 1

    # NOTE(review): "./mac_csv_helper" is a relative path, so this only
    # works when the tool is started from the directory that holds it.
    find_cmd = [
        "find", argv[0],
        "-name", "*.txt",
        "-exec", "./mac_csv_helper", "{}", ";",
    ]
    return subprocess.call(find_cmd)

# 处理 'list' 子命令
def handle_list(argv, argv0=MY_NAME):
    """Handle the 'list' sub-command: print the inventory stored in the DB.

    argv and argv0 are accepted for symmetry with the other handlers and
    are not used.

    Returns 0 on success, 1 when any exception is raised (the traceback and
    the exception text are printed).
    """
    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        analyzer.print_inventory(inventory_ranges)

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0

# 处理 'show_correl' 子命令
def handle_show_correl(argv, argv0=MY_NAME):
    """Handle the 'show_correl' sub-command by calling db_operator.show_correl().

    argv and argv0 are accepted for symmetry with the other handlers and
    are not used.

    Returns 0 on success, 1 when show_correl() raises (the traceback and the
    exception text are printed).
    """
    status = 0
    try:
        db_operator.show_correl()
    except Exception as err:
        # print_exc() reports the exception currently being handled.
        traceback.print_exc()
        print("")
        print(err)
        status = 1
    return status



# 处理 'correl_all' 子命令
def handle_correl_all(argv, argv0=MY_NAME):
    """Handle the 'correl_all' sub-command: correlate every pair of codes.

    argv  -- arguments after the sub-command.  With exactly one argument it
             is read as a code-list file and the inventory is restricted to
             the codes found in it; with any other argument count the
             inventory comes from db_operator.get_inventory2().
    argv0 -- unused, kept for symmetry with the other handlers.

    Returns 0 on success, 1 when any exception is raised (the traceback and
    the exception text are printed).
    """
    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()

        if 1 == len(argv):
            inventory_ranges = db_operator.get_inventory(dbcur)
            # Only the codes listed in the file are of interest.
            codes_from_file = his_loader.codes_from_file(argv[0])

            # Iterate over a snapshot of the keys; entries are deleted from
            # the mapping while we walk it.
            for code in list(inventory_ranges.keys()):
                if code not in codes_from_file:
                    del inventory_ranges[code]

            analyzer.print_inventory(inventory_ranges)
        else:
            inventory_ranges = db_operator.get_inventory2(dbcur)

        analyzer.correlation_all(inventory_ranges, dbcur)

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0


# 处理 'correl' 子命令
def handle_correl(argv, argv0=MY_NAME):
    """Handle the 'correl' sub-command: correlation between two codes.

    argv  -- arguments after the sub-command; argv[0] and argv[1] are the
             two codes to compare.
    argv0 -- program name used in the usage message.

    Returns 0 on success.  Returns 1 when fewer than two codes are given,
    when a code is not in the inventory, or when any exception is raised
    (the traceback and the exception text are printed).
    """
    if len(argv) < 2:
        print("  %s correl    <代码1> <代码2> 考察指定两个代码的相关性 " % argv0)
        print("\n 未指定两个代码\n")
        return 1

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        # Check the codes in the order given so the first unknown one is
        # the one reported.
        for code in (argv[0], argv[1]):
            if code not in inventory_ranges:
                print("%s not exists." % code)
                return 1

        sec1 = inventory_ranges[argv[0]]
        sec2 = inventory_ranges[argv[1]]

        # The third element of the result is not reported.
        r_close, r_delta, _num = analyzer.correlation(dbcur, sec1, sec2)
        print("%s %s , 收盘价关联度 = %f，涨跌幅关联度 = %f" % (
            str(sec1), str(sec2), r_close, r_delta))

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0

# 处理 'faster_horse_all' 子命令
def handle_faster_horse_all(argv, argv0=MY_NAME):
    """Handle the 'faster_horse_all' sub-command.

    Restricts the inventory to the codes listed in the file argv[0], prints
    that inventory and passes it to analyzer.faster_horse_all().

    argv  -- arguments after the sub-command; argv[0] is the code-list file.
    argv0 -- program name used in the usage message.

    Returns 0 on success.  Returns 1 when the file argument is missing or
    when any exception is raised (the traceback and the exception text are
    printed).
    """
    # Validate the arguments before opening the DB: printing the usage
    # does not need a connection.
    if len(argv) < 1:
        print("  %s faster_horse_all  <代码列表> " % argv0)
        return 1

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()

        inventory_ranges = db_operator.get_inventory(dbcur)
        # Only the codes listed in the file are of interest.
        codes_from_file = his_loader.codes_from_file(argv[0])

        # Iterate over a snapshot of the keys; entries are deleted from the
        # mapping while we walk it.
        for code in list(inventory_ranges.keys()):
            if code not in codes_from_file:
                del inventory_ranges[code]

        analyzer.print_inventory(inventory_ranges)

        analyzer.faster_horse_all(inventory_ranges, dbcur)

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0


def handle_probe_faster_horse(argv, argv0=MY_NAME):
    """Handle the 'probe_faster_horse' sub-command.

    Calls analyzer.bt_faster_horse() for two codes with every (short, mid)
    period pair where short runs over 1..10 and mid over 2*short..40, and
    prints one result line per pair.

    argv  -- arguments after the sub-command; argv[0] and argv[1] are the
             two codes.
    argv0 -- program name used in the usage message.

    Returns 0 on success.  Returns 1 when fewer than two codes are given,
    when a code is not in the inventory, or when any exception is raised
    (the traceback and the exception text are printed).
    """
    if len(argv) < 2:
        # Name the sub-command that main() actually dispatches to here.
        print("  %s probe_faster_horse   <代码1> <代码2> 所有‘换快马’策略的所有参数组合" % argv0)
        print("\n 未指定两个代码\n")
        return 1

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        # Check the codes in the order given so the first unknown one is
        # the one reported.
        for code in (argv[0], argv[1]):
            if code not in inventory_ranges:
                print("%s not exists." % code)
                return 1

        sec1 = inventory_ranges[argv[0]]
        sec2 = inventory_ranges[argv[1]]

        for ma_short in range(1, 11):
            # The mid period starts at twice the short one and stops at 40.
            for ma_mid in range(2 * ma_short, 41):
                net_value, day_num, t_num, leg1, leg2 = analyzer.bt_faster_horse(
                    dbcur, sec1, sec2, ma_short, ma_mid)
                print("%s %s ‘换快马(%d-%d)’, %d 交易日，净值 %f, 换(建)仓 %d 次, 同期两脚净值 %f,%f" % (
                    str(sec1), str(sec2), ma_short, ma_mid,
                    day_num, net_value, t_num,
                    leg1, leg2))

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0


# Handle the 'faster_horse' sub-command: back-test the 'switch to the faster horse' strategy
def handle_bt_faster_horse(argv, argv0=MY_NAME):
    """Handle the 'faster_horse' sub-command: back-test on two codes.

    argv  -- arguments after the sub-command:
             argv[0], argv[1]  the two codes (required)
             argv[2]           short MA period, default 1
             argv[3]           mid MA period, default 20
             argv[4]           start day, yyyy-mm-dd, default ""
             argv[5]           end day, yyyy-mm-dd, default ""
    argv0 -- program name used in the usage message.

    Returns 0 on success; 1 when the codes are missing or unknown, or when
    any exception is raised during the back-test; 2 when a period or a date
    is out of range / malformed; 3 when a period is not an integer.
    """
    if len(argv) < 2:
        # main() dispatches this handler as 'faster_horse'.
        print("  %s faster_horse    <代码1> <代码2> [短周期=1] [中周期=20] [开始] [结束] 回测‘换快马’策略" % argv0)
        print("\n 未指定两个代码\n")
        return 1

    MA_Size1 = 1
    MA_Size2 = 20

    start_day = ""
    end_day = ""

    # Each period is optional on its own: a lone third argument sets the
    # short period and leaves the mid period at its default.
    if len(argv) >= 3:
        try:
            MA_Size1 = int(argv[2])
            if len(argv) >= 4:
                MA_Size2 = int(argv[3])
        except ValueError:
            print("短期MA天数/中期MA天数， 都必须是正整数。")
            return 3

        # Both must be positive (0 is rejected too) and short < mid.
        if MA_Size1 <= 0 or MA_Size2 <= 0 or MA_Size1 >= MA_Size2:
            print("短期MA天数 必须小于 中期MA天数， 且都必须是正整数。")
            return 2

    if len(argv) >= 5:
        start_day = argv[4]
        if not data_struct.is_yyyy_mm_dd(start_day):
            print("开始日期必须是 yyyy-mm-dd")
            return 2

    if len(argv) >= 6:
        end_day = argv[5]
        if not data_struct.is_yyyy_mm_dd(end_day):
            print("结束日期必须是 yyyy-mm-dd")
            return 2

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        # Check the codes in the order given so the first unknown one is
        # the one reported.
        for code in (argv[0], argv[1]):
            if code not in inventory_ranges:
                print("%s not exists." % code)
                return 1

        sec1 = inventory_ranges[argv[0]]
        sec2 = inventory_ranges[argv[1]]

        net_value, day_num, t_num, leg1, leg2 = analyzer.bt_faster_horse(
            dbcur, sec1, sec2, MA_Size1, MA_Size2, start_day, end_day)

        # The date range is shown only when a start day was given.
        scope = ""
        if "" != start_day:
            scope = " %s~%s" % (start_day, end_day)

        # The last figure is the net value minus the mean of the two legs.
        print("%s %s ‘换快马(%d-%d)%s’, %d T，V %f, 交易 %d 次, 同期两脚V %f/%f, 业绩%f " % (
            str(sec1), str(sec2), MA_Size1, MA_Size2, scope,
            day_num, net_value, t_num,
            leg1, leg2,
            net_value - (leg1 + leg2) / 2))

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0

# Handle the 'faster_horse2' sub-command: back-test the multi-code 'faster horse' strategy
def handle_faster_horse2(argv, argv0=MY_NAME):
    """Handle the 'faster_horse2' sub-command: back-test on several codes.

    argv  -- arguments after the sub-command:
             argv[0]  comma-separated list of codes (required)
             argv[1]  MA period, default 1
             argv[2]  switch threshold, default 1
             argv[3]  start day, yyyy-mm-dd, default ""
             argv[4]  end day, yyyy-mm-dd, default ""
    argv0 -- program name used in the usage message.

    Returns 0 on success; 1 when the code list is missing, a code is
    unknown, or any exception is raised during the back-test; 2 when a
    number or a date is out of range / malformed; 3 when a number is not
    an integer.
    """
    if len(argv) < 1:
        print("  %s faster_horse2   <代码1>,<代码2>[,<代码3> ...] [MA周期=1] [门槛]  [开始日] [结束日] 回测‘换多马’策略" % argv0)
        return 1

    MA_Size1 = 1
    switch_threshold = 1

    start_day = ""
    end_day = ""

    if len(argv) >= 2:
        try:
            MA_Size1 = int(argv[1])
        except ValueError:
            # This command has a single MA period; say so.
            print("MA天数 必须是正整数。")
            return 3
        # "Positive" excludes 0 as well as negative values.
        if MA_Size1 < 1:
            print("MA天数 必须是正整数。")
            return 2

    if len(argv) >= 3:
        try:
            switch_threshold = int(argv[2])
        except ValueError:
            print("换仓门槛不能小于1。")
            return 3
        if switch_threshold < 1:
            print("换仓门槛不能小于1。")
            return 2

    if len(argv) >= 4:
        start_day = argv[3]
        if not data_struct.is_yyyy_mm_dd(start_day):
            print("开始日期必须是 yyyy-mm-dd")
            return 2

    if len(argv) >= 5:
        end_day = argv[4]
        if not data_struct.is_yyyy_mm_dd(end_day):
            print("结束日期必须是 yyyy-mm-dd")
            return 2

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        secs = []
        for c in argv[0].split(','):
            if c not in inventory_ranges:
                print("%s not exists." % c)
                return 1
            secs.append(inventory_ranges[c])

        net_value, day_num, t_num, legs = analyzer.faster_horse2(
            dbcur, secs, MA_Size1, switch_threshold, start_day, end_day)

        # The date range is shown only when a start day was given.
        scope = ""
        if "" != start_day:
            scope = " %s~%s" % (start_day, end_day)

        # One line with every security, each followed by a space.
        print("".join(["%s " % sec for sec in secs]))

        # The last figure is the net value minus the mean of all legs.
        print("‘换快马%s’, %d T，V %f, 交易 %d 次, 同期各脚 %s, 业绩%f " % (
            scope,
            day_num, net_value, t_num,
            str(legs),
            net_value - sum(legs) / len(legs)))

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0

# 处理 'rotate' 子命令 , ‘轮换’策略的回测
def handle_rotate(argv, argv0=MY_NAME):
    """Handle the 'rotate' sub-command: back-test the rotation strategy.

    argv  -- arguments after the sub-command:
             argv[0]  comma-separated list of codes (required)
             argv[1]  maximum number of codes held, default 1
             argv[2]  start day, yyyy-mm-dd, default ""
             argv[3]  end day, yyyy-mm-dd, default ""
    argv0 -- program name used in the usage message.

    Returns 0 on success; 1 when the code list is missing, a code is
    unknown, or any exception is raised during the back-test; 2 when the
    holding limit or a date is out of range / malformed; 3 when the holding
    limit is not an integer.
    """
    if len(argv) < 1:
        # This is the rotation strategy, as listed by print_usage().
        print("  %s rotate  <代码1>,<代码2>[,<代码3> ...] [最多持仓品种=1] [开始日] [结束日] 回测‘轮换’策略" % argv0)
        return 1

    max_hold = 1

    start_day = ""
    end_day = ""

    if len(argv) >= 2:
        try:
            max_hold = int(argv[1])
        except ValueError:
            print("最多持仓品种数必须>=1 。")
            return 3
        # 1 is accepted (it is the default), so the limit is ">= 1".
        if max_hold < 1:
            print("最多持仓品种数必须>=1 。")
            return 2

    if len(argv) >= 3:
        start_day = argv[2]
        if not data_struct.is_yyyy_mm_dd(start_day):
            print("开始日期必须是 yyyy-mm-dd")
            return 2

    if len(argv) >= 4:
        end_day = argv[3]
        if not data_struct.is_yyyy_mm_dd(end_day):
            print("结束日期必须是 yyyy-mm-dd")
            return 2

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        inventory_ranges = db_operator.get_inventory(dbcur)

        secs = []
        for c in argv[0].split(','):
            if c not in inventory_ranges:
                print("%s not exists." % c)
                return 1
            secs.append(inventory_ranges[c])

        net_value, day_num, t_num, legs = analyzer.rotate(
            dbcur, secs, max_hold, start_day, end_day)

        # The date range is shown only when a start day was given.
        scope = ""
        if "" != start_day:
            scope = " %s~%s" % (start_day, end_day)

        # One line with every security, each followed by a space.
        print("".join(["%s " % sec for sec in secs]))

        # The last figure is the net value minus the mean of all legs.
        print("‘轮换%s’, %d T，V %f, 交易 %d 次, 同期各脚 %s, 业绩%f " % (
            scope,
            day_num, net_value, t_num,
            str(legs),
            net_value - sum(legs) / len(legs)))

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0



# 处理 'gen_alpha' 子命令
def handle_gen_alpha(argv, argv0=MY_NAME):
    """Handle the 'gen_alpha' sub-command.

    Passes argv[0] as the benchmark code to analyzer.gen_alpha() and commits
    the result.

    argv  -- arguments after the sub-command; argv[0] is the benchmark code.
    argv0 -- program name used in the usage message.

    Returns 0 on success.  Returns 1 when the code is missing or when any
    exception is raised (the traceback and the exception text are printed).
    """
    if len(argv) < 1:
        print("  %s gen_alpha  <代码>  以<代码>为基准，设定其他代码的alpha" % argv0)
        print("\n 未指定<代码>\n")
        return 1

    # Pre-bind both handles: if get_db_conn() or cursor() raises, the
    # finally clause must not trip over unbound names and hide the error.
    conn = None
    dbcur = None
    try:
        conn = db_operator.get_db_conn()
        dbcur = conn.cursor()
        # NOTE(review): the returned inventory is not used here; the call is
        # kept in case get_inventory() has side effects -- confirm and drop.
        db_operator.get_inventory(dbcur)

        analyzer.gen_alpha(dbcur, argv[0])

        conn.commit()
    except Exception as e:
        traceback.print_exc()
        print("")
        print(e)
        return 1
    finally:
        if dbcur is not None:
            dbcur.close()
        if conn is not None:
            conn.close()

    return 0


def make_sure_working_dir():
    """Create data_struct.WORKING_DIR by running `mkdir -p` on it.

    Raises Exception when the mkdir process exits with a non-zero status.
    """
    mkdir_cmd = ["mkdir", "-p", data_struct.WORKING_DIR]
    exit_status = subprocess.call(mkdir_cmd)
    if exit_status != 0:
        raise Exception("Failed to make '%s'" % data_struct.WORKING_DIR)

#    subprocess.call([ "ln "
#            , "-s"
#            , "%s/ggchart_loader.js" % data_struct.JSLIB 
#            , data_struct.WORKING_DIR 
#            ])

def main():
    """Dispatch sys.argv[1] to the matching sub-command handler.

    Returns the handler's status.  Returns 1 when no sub-command is given,
    when it is 'help', or when it is unknown; the usage is printed in all
    three cases.
    """
    if len(sys.argv) < 2:
        print_usage()
        return 1

    make_sure_working_dir()

    sub_command = sys.argv[1]

    # Sub-command name -> handler; every handler receives the arguments
    # that follow the sub-command.
    handlers = {
        'import': handle_import,
        'importdir': handle_importdir,
        'list': handle_list,
        'show_correl': handle_show_correl,
        'correl': handle_correl,
        'correl_all': handle_correl_all,
        'gen_alpha': handle_gen_alpha,
        'faster_horse': handle_bt_faster_horse,
        'faster_horse2': handle_faster_horse2,
        'rotate': handle_rotate,
        'probe_faster_horse': handle_probe_faster_horse,
        'faster_horse_all': handle_faster_horse_all,
    }

    handler = handlers.get(sub_command)
    if handler is not None:
        return handler(sys.argv[2:])

    # 'help' and unknown sub-commands both end with the usage and status 1;
    # only the unknown case is announced first.
    if 'help' != sub_command:
        print("\n无效的‘子命令’ -- %s\n" % (sub_command,))
    print_usage()
    return 1

# Run main() only when executed as a script; its result is the exit status.
if __name__ == "__main__":
    sys.exit(main())


